import ast
import asyncio
import importlib
import inspect
import sys
from abc import ABC
from functools import lru_cache

import black
import country_converter as coco
from chainlite import chain
from pydantic import BaseModel
from tqdm import tqdm

from cc_news.datasets.common_schema import AbstractEvent
from log_utils import get_logger

# Module-level logger obtained from the project's log_utils helper.
logger = get_logger(__name__)


# Shared converter instance, created once at import time and reused by the
# country-name helper functions in this module.
cc = coco.CountryConverter()


@lru_cache
def load_module_from_path(file_path, module_name="schema_definition"):
    """Import the Python file at ``file_path`` as a module named ``module_name``.

    The loaded module is registered in ``sys.modules`` under ``module_name``
    and the result is memoized by ``lru_cache``, so repeated calls with the
    same arguments return the same module object without re-executing it.

    Args:
        file_path: Location of the Python source file to load.
        module_name: Name to give the module (and its ``sys.modules`` key).

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec/loader can be created for the file.
        Exception: Whatever the module's own top-level code raises; in that
            case ``sys.modules`` is left as it was before the call.
    """
    # `import importlib` alone does not guarantee that the `util` submodule
    # has been loaded, so import it explicitly here.
    import importlib.util

    # Create a module spec from the file location
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot find module named {module_name} at {file_path}")

    # Create a new module based on the spec
    module = importlib.util.module_from_spec(spec)

    # Register the module before executing it so that code inside the module
    # can resolve references to itself while it is being executed.
    previous = sys.modules.get(module_name)
    sys.modules[module_name] = module
    try:
        # Execute the module in its own namespace
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind: put back whatever
        # was registered under this name before the call.
        if previous is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise

    return module


@chain
def extract_tag_from_llm_output(llm_output: str, tag: str):
    """Return the text between the first ``<tag>`` and the following ``</tag>``.

    The extracted text is stripped of surrounding whitespace, and a single
    leading ``-`` (plus any whitespace after it) is removed.

    Args:
        llm_output: The text to search.
        tag: The tag name, without angle brackets.

    Returns:
        The extracted text, or an empty string when either the opening tag
        or a closing tag after it is missing.
    """
    opening_tag = f"<{tag}>"
    opening_index = llm_output.find(opening_tag)
    # Check the raw find() result: once the tag length has been added, a
    # "not found" (-1) result can no longer be told apart from a real match.
    if opening_index < 0:
        return ""

    tag_start = opening_index + len(opening_tag)
    tag_end = llm_output.find(f"</{tag}>", tag_start)
    if tag_end < 0:
        return ""

    ret = llm_output[tag_start:tag_end].strip()
    if ret.startswith("-"):
        ret = ret[1:].strip()

    return ret


@chain
def string_to_list(llm_output: str):
    """Split text into a list of non-empty, cleaned-up lines.

    Each line is stripped of surrounding whitespace; if it then starts with
    ``-``, that first dash is dropped and the remainder is stripped again.
    Lines that end up empty are omitted from the result.
    """
    items = []
    for raw_line in llm_output.split("\n"):
        line = raw_line.strip()
        if line.startswith("-"):
            # Drop only the first dash; later dashes are part of the item.
            line = line[1:].strip()
        if line:
            items.append(line)
    return items


def get_schema_path() -> str:
    """Return the (relative) path of the schema definition file."""
    schema_path = "schema_definition.py"
    return schema_path



@lru_cache
def normalize_country_name(country_name: str) -> str:
    """Convert ``country_name`` to the converter's ``name_short`` form.

    The lookup is delegated to the shared ``cc`` converter with
    ``not_found="not found"`` (presumably the value returned for names the
    converter cannot match -- confirm against the country_converter docs).
    Results are memoized by ``lru_cache``.
    """
    return cc.convert(names=country_name, to="name_short", not_found="not found")


@lru_cache
def country_name_to_code(country_name: str) -> str:
    """Convert ``country_name`` to the converter's ``ISO2`` form.

    The lookup is delegated to the shared ``cc`` converter and memoized by
    ``lru_cache``.
    """
    return cc.convert(names=country_name, to="ISO2")

def format_python_code(python_code):
    """Format Python source text with black's default mode and strip the result."""
    default_mode = black.Mode()
    formatted = black.format_str(python_code, mode=default_mode)
    return formatted.strip()

def string_to_event_object(string: str) -> AbstractEvent:
    """Evaluate a Python expression string against the schema module's names.

    Non-string inputs are returned unchanged. A string input is stripped and
    evaluated with ``eval``; every attribute of the loaded schema module whose
    name does not start with ``__`` is available to the expression, and
    ``__builtins__`` is set to ``None``.

    NOTE(review): setting ``__builtins__`` to ``None`` is not a security
    sandbox. If ``string`` can originate from untrusted input (e.g. model
    output), this ``eval`` should be reviewed -- confirm with the callers.
    """
    if not isinstance(string, str):
        return string

    module = load_module_from_path(get_schema_path(), module_name="schema_definition")

    # Namespace exposed to the evaluated expression.
    module_dict = {}
    for attr in dir(module):
        if attr.startswith("__"):
            continue
        module_dict[attr] = getattr(module, attr)

    expression = string.strip()
    return eval(expression, {"__builtins__": None}, module_dict)


def is_pydantic_class(class_obj):
    """Return True when ``class_obj`` is a class that subclasses ``BaseModel``."""
    # issubclass() requires a class as its first argument, so rule out
    # non-class objects first.
    if not isinstance(class_obj, type):
        return False
    return issubclass(class_obj, BaseModel)


def get_class_names_in_order(module_path):
    """Return the names of all classes defined in a Python source file.

    Useful for keeping related event types together.

    The order is breadth-first over the syntax tree: classes at the top level
    of the file come first, in source order, followed by classes nested one
    level deeper (in other classes or in functions), and so on.

    Args:
        module_path: Path of the Python source file to inspect.

    Returns:
        A list of class names; a name appears once per definition.
    """
    # Read raw bytes and let the parser detect the source encoding (UTF-8 by
    # default, or a PEP 263 coding declaration) instead of decoding with the
    # platform's locale encoding.
    with open(module_path, "rb") as file:
        module_content = file.read()

    tree = ast.parse(module_content)

    # Explicit level-by-level traversal: ast.walk() documents its order as
    # unspecified, while this function's purpose is to report an order.
    classes = []
    current_level = [tree]
    while current_level:
        next_level = []
        for node in current_level:
            if isinstance(node, ast.ClassDef):
                classes.append(node.name)
            next_level.extend(ast.iter_child_nodes(node))
        current_level = next_level

    return classes


def extract_classes_info(include_abstract: bool):
    """Collect information about the event classes in the schema module.

    The schema module is located via ``get_schema_path()`` and loaded with
    ``load_module_from_path()``. A class is included when it is a pydantic
    model, is a subclass of the module's ``AbstractEvent`` (but not
    ``AbstractEvent`` itself) and, unless ``include_abstract`` is true, does
    not list ``ABC`` directly among its bases.

    Args:
        include_abstract: When true, classes that directly inherit from
            ``ABC`` are included as well.

    Returns:
        dict: Maps class names to dictionaries with the keys
            - 'class_definition': the result of ``class_obj.model_json_schema()``.
            - 'docstring': the class docstring as returned by ``inspect.getdoc``.
            - 'class_obj': the class object itself.
        Entries follow the order reported by ``get_class_names_in_order()``;
        classes that function does not report come last, ordered by name.
    """
    # Load the module from the schema path
    module_path = get_schema_path()
    schema_definition = load_module_from_path(module_path)

    # Position of the first occurrence of each class name in the source file.
    # A dict lookup replaces a linear list search per sort key.
    class_names = get_class_names_in_order(module_path)
    position_by_name = {}
    for position, name in enumerate(class_names):
        position_by_name.setdefault(name, position)
    fallback_position = len(class_names)

    # getmembers() returns the classes sorted by name; the stable sort below
    # reorders them by source position and keeps name order among the rest.
    classes = sorted(
        inspect.getmembers(schema_definition, inspect.isclass),
        key=lambda member: position_by_name.get(member[0], fallback_position),
    )

    event_type_to_prompt_parameters = {}

    # Iterate through all classes in the module
    for class_name, class_obj in classes:
        if (
            is_pydantic_class(class_obj)
            and issubclass(class_obj, schema_definition.AbstractEvent)
            and class_obj is not schema_definition.AbstractEvent
            and (
                include_abstract or ABC not in class_obj.__bases__
            )  # is not defined as an abstract class
        ):
            event_type_to_prompt_parameters[class_name] = {
                "class_definition": class_obj.model_json_schema(),
                "docstring": inspect.getdoc(class_obj),
                "class_obj": class_obj,
            }

    return event_type_to_prompt_parameters


async def run_async_in_parallel(
    async_function, iterable, max_concurrency: int, desc: str = ""
):
    """Apply ``async_function`` to every item of ``iterable`` concurrently.

    At most ``max_concurrency`` calls run inside the semaphore at a time, and
    progress is displayed with tqdm using ``desc`` as the description.

    Returns:
        A list with one entry per input item, in input order. An entry is the
        value returned by ``async_function`` for that item, or ``None`` when
        the call raised an ``Exception`` (which is logged, not re-raised).
    """
    limiter = asyncio.Semaphore(max_concurrency)  # caps concurrent calls

    async def guarded_call(f, item, original_index) -> tuple:
        # Yields (index, result, error message); exactly one of the last two
        # is None.
        async with limiter:
            try:
                value = await f(item)
            except Exception as e:
                logger.exception(f"Task {original_index} failed with error: {e}")
                return original_index, None, str(e)
            else:
                return original_index, value, None

    pending = [
        guarded_call(async_function, item, original_index)
        for original_index, item in enumerate(iterable)
    ]

    # Pre-sized so results can be stored by index as calls finish out of order.
    results = [None] * len(pending)
    progress = tqdm(
        asyncio.as_completed(pending), total=len(pending), smoothing=0, desc=desc
    )
    for finished in progress:
        original_index, value, error = await finished
        # Failed calls are represented by None in the output.
        results[original_index] = None if error else value

    return results
